function [hat_beta, SE, reject, reject0, pval, pval0] = implement_test_cluster(dy, z, shares, shifters, critical , ols, cluster_vec)
%IMPLEMENT_TEST_CLUSTER  t-test of beta = 0 used in estimation, allowing for clusters of shifters.
%
%   [hat_beta, SE, reject, reject0, pval, pval0] = implement_test_cluster( ...
%       dy, z, shares, shifters, critical, ols, cluster_vec)
%
%   Inputs
%     dy          - Nobs x 1 column vector (outcome).
%     z           - Nobs x 1 column vector (regressor).
%     shares      - Nobs x J matrix; column j is paired with shifters(j).
%     shifters    - J-vector of shifters (row or column is accepted).
%     critical    - threshold compared with |t| for the rejection decisions
%                   (e.g. 1.96 for a two-sided 5% normal test -- caller's choice).
%     ols         - 1: hat_beta = (z'*dy)/(z'*z) and the SE is rescaled accordingly;
%                   otherwise hat_beta = z'*dy/Nobs (unnormalized covariance).
%     cluster_vec - [] to treat shifters as independent; otherwise a J-vector of
%                   cluster labels (any type accepted by UNIQUE, e.g. numeric,
%                   string or cellstr). Shifters are treated as independent across
%                   clusters. A NaN label is treated as its own singleton cluster.
%
%   Outputs
%     hat_beta     - point estimate (see OLS above).
%     SE           - standard error built from residuals dy - beta_ols*z.
%     reject       - abs(hat_beta/SE)  > critical.
%     reject0      - abs(hat_beta/SE0) > critical, where SE0 is built from the
%                    residuals under the null beta = 0 (i.e. dy itself).
%     pval, pval0  - two-sided standard-normal p-values of the two t-statistics.
%
%   The variance estimate is sum over clusters g of
%       ( sum_{j in g} shifters(j) * sum_i eta(i)*shares(i,j) / Nobs )^2,
%   with one term per shifter when cluster_vec is empty.
%
%   Corresponding author: Rodrigo Adao
%   Date: 09/11/2024

Nobs = length(dy);
shifters = shifters(:);          % accept row or column shifters
var_z = sum(z.^2)/Nobs;          % z'*z / Nobs, normalizer of the OLS version

%Statistic estimation
cov_zy = (z'*dy)/Nobs;
if ols == 1
    %OLS version -- normalize covariance by variance of IV
    hat_beta = cov_zy/var_z;
else
    %Covariance version
    hat_beta = cov_zy;
end
% Residuals always use the OLS slope, whichever statistic is reported.
hat_eta  = dy - (cov_zy/var_z)*z;
hat_eta0 = dy;                   % residuals under the null beta = 0

% Shifter-level scores: RX(j) = shifters(j) * sum_i eta(i)*shares(i,j) / Nobs
RX0 = (hat_eta0'*shares)'.*shifters/Nobs;
RX  = (hat_eta'*shares)'.*shifters/Nobs;

%Compute variance term
if isempty(cluster_vec)
    %Shifters are iid: sum of squared scores
    hatV0_beta = sum(RX0.^2);
    hatV_beta  = sum(RX.^2);
else
    %Shifters are iid across clusters: square the within-cluster sum of scores.
    % (sum(x))^2 equals sum(sum(x*x')) without forming a Jc x Jc matrix.
    [~, ~, cluster_id] = unique(cluster_vec(:));
    hatV0_beta = sum(accumarray(cluster_id, RX0).^2);
    hatV_beta  = sum(accumarray(cluster_id, RX).^2);
end

%Standard error and tests
SE  = sqrt(hatV_beta);
SE0 = sqrt(hatV0_beta);
if ols == 1
    % Scale the covariance SE to the OLS slope.
    SE  = SE/var_z;
    SE0 = SE0/var_z;
end

%rejection decision and p-value
tstat  = hat_beta/SE;
tstat0 = hat_beta/SE0;

reject  = abs(tstat)  > critical;
reject0 = abs(tstat0) > critical;

% 2*(1 - Phi(|t|)) == erfc(|t|/sqrt(2)). The erfc form avoids the cancellation
% in 1 - Phi (which rounds to 0 once |t| exceeds roughly 8) and needs no
% Statistics Toolbox.
pval  = erfc(abs(tstat) /sqrt(2));
pval0 = erfc(abs(tstat0)/sqrt(2));

end